import os
import scipy
import scipy.io
import numpy as np
from PIL import Image
import torch
from torchvision import transforms, datasets
from torch.utils.data import DataLoader, Dataset
import numpy as np
import pickle

# ImageNet type datasets, i.e., which support loading with ImageFolder
def imagenette(datadir="./datasets/imagenet_subsets/imagenette2/", batch_size=128, mode="org", size=224, normalize=False, norm_layer=None, workers=4, distributed=False, **kwargs):
    """Build train/val loaders for an ImageFolder-style dataset rooted at ``datadir``.

    Args:
        datadir: Directory containing ``train`` and ``val`` sub-directories.
        batch_size: Batch size used by both loaders.
        mode: ``"org"`` trains with random-resized-crop + horizontal flip;
            ``"base"`` trains with the evaluation transform instead.
        size: Final crop size. Evaluation images are first resized to
            ``int(1.14 * size)`` and then center-cropped.
        normalize: Only consulted when ``norm_layer`` is None. True selects the
            hard-coded channel statistics below, False an identity normalization.
        norm_layer: Optional transform appended after ``ToTensor``.
        workers: Number of loader worker processes.
        distributed: When True, both sets are wrapped in a ``DistributedSampler``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler, None, None,
        transform_train)``; the samplers are None unless ``distributed`` is set.

    Raises:
        ValueError: If ``mode`` is neither ``"org"`` nor ``"base"``.
    """
    if norm_layer is None:
        if normalize:
            # presumably per-channel statistics of this dataset -- TODO confirm
            channel_mean, channel_std = [0.4648, 0.4543, 0.4247], [0.2785, 0.2735, 0.2944]
        else:
            # mean 0 / std 1 leaves the tensor values unchanged
            channel_mean, channel_std = [0., 0., 0.], [1., 1., 1.]
        norm_layer = transforms.Normalize(mean=channel_mean, std=channel_std)

    augmented = transforms.Compose([
        transforms.RandomResizedCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_layer,
    ])
    transform_test = transforms.Compose([
        transforms.Resize(int(1.14 * size)),
        transforms.CenterCrop(size),
        transforms.ToTensor(),
        norm_layer,
    ])

    if mode == "base":
        transform_train = transform_test
    elif mode == "org":
        transform_train = augmented
    else:
        raise ValueError(f"{mode} mode not supported")

    train_root = os.path.join(datadir, "train")
    val_root = os.path.join(datadir, "val")
    trainset = datasets.ImageFolder(train_root, transform=transform_train)
    testset = datasets.ImageFolder(val_root, transform=transform_test)

    train_sampler = test_sampler = None
    if distributed:
        print("Using DistributedSampler")
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset)

    shared = dict(batch_size=batch_size, num_workers=workers, pin_memory=True)
    # shuffle and sampler are mutually exclusive, so only shuffle without a sampler
    train_loader = DataLoader(trainset, shuffle=(train_sampler is None), sampler=train_sampler, **shared)
    test_loader = DataLoader(testset, shuffle=False, sampler=test_sampler, **shared)

    return train_loader, train_sampler, test_loader, test_sampler, None, None, transform_train


# cifar10
def cifar10(datadir="./datasets/", batch_size=128, mode="org", size=32, normalize=False, norm_layer=None, workers=4, distributed=False, **kwargs):
    """Build CIFAR-10 train/test loaders, downloading the data if needed.

    Args:
        datadir: Parent directory; the data lives in ``<datadir>/cifar10``.
        batch_size: Batch size used by both loaders.
        mode: ``"org"`` trains with padded random crop + horizontal flip;
            ``"base"`` trains with the evaluation transform instead.
        size: Target image size. A ``Resize(size)`` step is prepended to both
            pipelines when it differs from 32.
        normalize: Only consulted when ``norm_layer`` is None. True selects the
            hard-coded channel statistics below, False an identity normalization.
        norm_layer: Optional transform appended after ``ToTensor``.
        workers: Number of loader worker processes.
        distributed: When True, both sets are wrapped in a ``DistributedSampler``.
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler, None, None,
        transform_train)``; the samplers are None unless ``distributed`` is set.

    Raises:
        ValueError: If ``mode`` is neither ``"org"`` nor ``"base"``.
    """
    if norm_layer is None:
        if normalize:
            channel_mean, channel_std = [0.491, 0.482, 0.447], [0.202, 0.199, 0.201]
        else:
            # mean 0 / std 1 leaves the tensor values unchanged
            channel_mean, channel_std = [0., 0., 0.], [1., 1., 1.]
        norm_layer = transforms.Normalize(mean=channel_mean, std=channel_std)

    train_steps, eval_steps = [], []
    if size != 32:
        # Non-native resolution: rescale before any other step.
        train_steps.append(transforms.Resize(size))
        eval_steps.append(transforms.Resize(size))
    train_steps.extend([
        transforms.RandomCrop(size, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        norm_layer,
    ])
    eval_steps.extend([transforms.ToTensor(), norm_layer])
    augmented = transforms.Compose(train_steps)
    transform_test = transforms.Compose(eval_steps)

    if mode == "base":
        transform_train = transform_test
    elif mode == "org":
        transform_train = augmented
    else:
        raise ValueError(f"{mode} mode not supported")

    root = os.path.join(datadir, "cifar10")
    trainset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)

    train_sampler = test_sampler = None
    if distributed:
        print("Using DistributedSampler")
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset)

    shared = dict(batch_size=batch_size, num_workers=workers, pin_memory=True)
    # Both loaders shuffle unless a sampler has taken over the ordering.
    train_loader = DataLoader(trainset, shuffle=(train_sampler is None), sampler=train_sampler, **shared)
    test_loader = DataLoader(testset, shuffle=(test_sampler is None), sampler=test_sampler, **shared)

    return train_loader, train_sampler, test_loader, test_sampler, None, None, transform_train


class custom_loader_from_image_list(torch.utils.data.Dataset):
    """Map-style dataset over a list of image paths.

    Despite its historical name this object is a *dataset*: it exposes
    ``__len__``/``__getitem__`` and is meant to be handed to a ``DataLoader``
    or sampler. It therefore derives from ``Dataset`` rather than from
    ``DataLoader`` (whose constructor was never invoked).

    Args:
        f: Either a path (``str`` / ``os.PathLike``) to a text file read with
            ``np.genfromtxt(..., delimiter=',', dtype=str)``, or a numpy array
            of image paths.
        label_extractor: Callable mapping one entry of the path list to its label.
        classes: Optional sequence of original labels to keep. Kept samples are
            relabelled with the position of their class in ``classes`` and then
            shuffled with ``np.random.permutation``.
        training_images: Optional count; when truthy only the first
            ``training_images`` entries are kept (applied before ``classes``).
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        ValueError: If ``f`` is neither a path nor a numpy array.
    """

    def __init__(self, f, label_extractor, classes=None, training_images=None, transform=None):
        super().__init__()
        if isinstance(f, (str, os.PathLike)):
            # genfromtxt collapses a one-line file into a 0-d array, which can
            # be neither iterated nor sliced; force at least one dimension.
            self.images = np.atleast_1d(np.genfromtxt(f, delimiter=',', dtype=str))
            origin = f" from {f}"
        elif isinstance(f, np.ndarray):
            self.images = f
            origin = ""
        else:
            raise ValueError("incorrect file format for input f")
        self.labels = np.array([label_extractor(i) for i in self.images])
        self.transform = transform
        print(f"Loaded original data of {len(self.labels)} images{origin}")

        total = len(self.labels)
        if training_images:
            print("Assuming that indexes of stored images are pre-sorted")
            print(f"Selecting first {training_images} images from total {total} available images")
            self.images = self.images[:training_images]
            self.labels = self.labels[:training_images]

        # Explicit None/length test so numpy arrays are accepted as `classes`
        # (their truth value is ambiguous).
        if classes is not None and len(classes) > 0:
            print(f"Loading only class {classes} images.")
            per_class = [np.where(self.labels == original)[0] for original in classes]
            valid_indices = np.concatenate(per_class)
            # The new label is the position of the class inside `classes`;
            # an integer dtype is kept even when nothing matches.
            new_labels = np.concatenate(
                [np.full(len(hits), new, dtype=int) for new, hits in enumerate(per_class)]
            )

            self.images = self.images[valid_indices]
            self.labels = new_labels
            # Shuffle here as well, in case the consumer forgets to shuffle in
            # its DataLoader (samples are otherwise grouped by class).
            order = np.random.permutation(len(self.labels))
            self.images, self.labels = self.images[order], self.labels[order]
            print(f"Carved dataset for only {classes} classes comprising {len(self.labels)} images")

    def __len__(self):
        """Return the number of image paths currently held."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB, apply ``transform`` if set, return ``(img, label)``."""
        path = self.images[idx]
        # Open the file ourselves and convert inside the `with` block so the
        # pixel data is read before the handle closes. The original author
        # notes that Image.open(path) is extremely slow and that this form is
        # the one recommended by torchvision.
        with open(path, 'rb') as f:
            img = Image.open(f)
            img = img.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label


def cifar2(datadir="./datasets/", batch_size=128, mode="org", size=32, normalize=False, norm_layer=None, workers=4, distributed=False, classes=None, **kwargs):
    """Build CIFAR-10 loaders restricted to two classes relabelled as 0 and 1.

    Args:
        datadir: Parent directory; the data lives in ``<datadir>/cifar10``.
        batch_size: Batch size used by both loaders.
        mode: ``"org"`` trains with padded random crop + horizontal flip;
            ``"base"`` trains with the evaluation transform instead.
        size: Target image size. A ``Resize(size)`` step is prepended to both
            pipelines when it differs from 32.
        normalize: Only consulted when ``norm_layer`` is None. True selects the
            hard-coded channel statistics below, False an identity normalization.
        norm_layer: Optional transform appended after ``ToTensor``.
        workers: Number of loader worker processes.
        distributed: Not used by this function.
        classes: Optional pair of original CIFAR-10 labels. ``classes[0]`` is
            relabelled 0 and ``classes[1]`` is relabelled 1; both loaders then
            draw only those samples through a ``SubsetRandomSampler``. When
            None or empty the full dataset is used with its original labels.
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler, None, None,
        transform_train)``; the samplers are None when ``classes`` is not given.

    Raises:
        ValueError: If ``mode`` is unsupported or ``classes`` does not hold
            exactly two labels.
    """
    # Validate the cheap arguments first so a bad call fails before any
    # transform is built or any data is downloaded.
    if mode not in ("org", "base"):
        raise ValueError(f"{mode} mode not supported")
    use_subset = classes is not None and len(classes) > 0
    if use_subset and len(classes) != 2:
        raise ValueError(f"cifar2 expects exactly 2 classes, got {len(classes)}")

    def binarize(dataset):
        """Relabel the two selected classes to 0/1 in place; return their sample indices."""
        targets = np.array(dataset.targets)
        # Both index sets are taken from the untouched labels before either
        # class is rewritten, so the two relabelling steps cannot interfere.
        first = np.where(targets == classes[0])[0]
        second = np.where(targets == classes[1])[0]
        targets[first] = 0
        targets[second] = 1
        # Samples of other classes keep their old label; the subset sampler
        # built from the returned indices never draws them.
        dataset.targets = targets.tolist()
        return np.concatenate([first, second])

    if norm_layer is None:
        if normalize:
            norm_layer = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.202, 0.199, 0.201])
        else:
            # mean 0 / std 1 leaves the tensor values unchanged
            norm_layer = transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])

    trtrain = [transforms.RandomCrop(size, padding=4), transforms.RandomHorizontalFlip(),
               transforms.ToTensor(), norm_layer]
    trval = [transforms.ToTensor(), norm_layer]
    if size != 32:
        trtrain = [transforms.Resize(size)] + trtrain
        trval = [transforms.Resize(size)] + trval
    transform_test = transforms.Compose(trval)
    transform_train = transforms.Compose(trtrain) if mode == "org" else transform_test

    root = os.path.join(datadir, "cifar10")
    trainset = datasets.CIFAR10(root=root, train=True, download=True, transform=transform_train)
    testset = datasets.CIFAR10(root=root, train=False, download=True, transform=transform_test)

    train_sampler, test_sampler = None, None
    if use_subset:
        train_sampler = torch.utils.data.SubsetRandomSampler(binarize(trainset))
        test_sampler = torch.utils.data.SubsetRandomSampler(binarize(testset))

    # DataLoader rejects shuffle=True together with a sampler, so shuffling is
    # requested only when no sampler is in charge of the ordering. This also
    # makes the unfiltered training set shuffled, like the other loaders here.
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=(test_sampler is None), sampler=test_sampler, num_workers=workers, pin_memory=True)

    return train_loader, train_sampler, test_loader, test_sampler, None, None, transform_train


def diffusion_cifar10(datadir="./synthetic_dataset/diffusion/", batch_size=128, mode="org", size=32, normalize=False, norm_layer=None, workers=4, distributed=False, classes=None, training_images=None, syn_labels=None, **kwargs):
    """Build loaders over the image lists stored under ``./synthetic_dataset/diffusion/``.

    Args:
        datadir: Ignored; the data is always read from the fixed directory
            above (a message saying so is printed).
        batch_size: Batch size used by both loaders.
        mode: ``"org"`` trains with padded random crop + horizontal flip;
            ``"base"`` trains with the evaluation transform instead.
        size: Target image size. A ``Resize(size)`` step is prepended to both
            pipelines when it differs from 32.
        normalize: Must be False unless ``norm_layer`` is supplied: no channel
            statistics are defined for this dataset.
        norm_layer: Optional transform appended after ``ToTensor``.
        workers: Number of loader worker processes.
        distributed: When True, both sets are wrapped in a ``DistributedSampler``.
        classes: Optional labels to keep, forwarded to
            ``custom_loader_from_image_list``.
        training_images: Optional cap on the number of training images,
            forwarded to ``custom_loader_from_image_list`` (train split only).
        syn_labels: Which list files to read: ``"bit"``, ``"lanet"`` or
            ``"lanet_bit"``. Required; the default None is rejected.
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler, None, None,
        transform_train)``; the samplers are None unless ``distributed`` is set.

    Raises:
        ValueError: If ``mode`` or ``syn_labels`` is unsupported, or if
            ``normalize`` is True without an explicit ``norm_layer``.
    """
    # (train list, test list) relative to the fixed data directory.
    split_files = {
        "bit": ("train_split_seed_0.txt", "test_split_seed_0.txt"),
        "lanet": ("ddpm_labels/train_split_seed_0_lanet_all.txt", "ddpm_labels/test_split_seed_0_lanet_all.txt"),
        "lanet_bit": ("ddpm_labels/train_split_seed_0_lanet_bit.txt", "ddpm_labels/test_split_seed_0_lanet_bit.txt"),
    }

    # Validate every argument before touching transforms or the filesystem.
    if mode not in ("org", "base"):
        raise ValueError(f"{mode} mode not supported")
    if not isinstance(syn_labels, str) or syn_labels not in split_files:
        raise ValueError(f"Syn_labels {syn_labels} is not valid")
    if norm_layer is None and normalize:
        # The previous placeholder used std=[0, 0, 0], which only failed once
        # the first sample was normalized; fail here with a usable message.
        raise ValueError(
            "normalize=True is not supported for this dataset: no channel "
            "statistics are defined (std would be zero); pass norm_layer explicitly"
        )

    if norm_layer is None:
        # mean 0 / std 1 leaves the tensor values unchanged
        norm_layer = transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])

    trtrain = [transforms.RandomCrop(size, padding=4), transforms.RandomHorizontalFlip(),
               transforms.ToTensor(), norm_layer]
    trval = [transforms.ToTensor(), norm_layer]
    if size != 32:
        trtrain = [transforms.Resize(size)] + trtrain
        trval = [transforms.Resize(size)] + trval
    transform_test = transforms.Compose(trval)
    transform_train = transforms.Compose(trtrain) if mode == "org" else transform_test

    datadir = "./synthetic_dataset/diffusion/"
    print(f"Discarding args.datadir and loading data from fixed source: {datadir}. Using {workers} workers.")

    def label_extractor(path):
        """Read the label from the name of the image's parent directory."""
        return int(path.split("/")[-2])

    train_list, test_list = split_files[syn_labels]
    trainset = custom_loader_from_image_list(os.path.join(datadir, train_list), label_extractor, classes, training_images, transform_train)
    testset = custom_loader_from_image_list(os.path.join(datadir, test_list), label_extractor, classes, None, transform_test)

    train_sampler, test_sampler = None, None
    if distributed:
        print("Using DistributedSampler")
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset)

    # DataLoader rejects shuffle=True together with a sampler, so the test
    # loader keeps shuffling only in the non-distributed case.
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=(test_sampler is None), sampler=test_sampler, num_workers=workers, pin_memory=True)

    return train_loader, train_sampler, test_loader, test_sampler, None, None, transform_train


def styleganC_cifar10(datadir="./stylegan_ada/cifar10/", batch_size=128, mode="org", size=32, normalize=False, norm_layer=None, workers=4, distributed=False, classes=None, training_images=None, **kwargs):
    """Build loaders over the image lists stored under ``./stylegan_ada/cifar10/``.

    Args:
        datadir: Ignored; the data is always read from the fixed directory
            above (a message saying so is printed).
        batch_size: Batch size used by both loaders.
        mode: ``"org"`` trains with padded random crop + horizontal flip;
            ``"base"`` trains with the evaluation transform instead.
        size: Target image size. A ``Resize(size)`` step is prepended to both
            pipelines when it differs from 32.
        normalize: Must be False unless ``norm_layer`` is supplied: no channel
            statistics are defined for this dataset.
        norm_layer: Optional transform appended after ``ToTensor``.
        workers: Number of loader worker processes.
        distributed: When True, both sets are wrapped in a ``DistributedSampler``.
        classes: Optional labels to keep, forwarded to
            ``custom_loader_from_image_list``.
        training_images: Optional cap on the number of training images,
            forwarded to ``custom_loader_from_image_list`` (train split only).
        **kwargs: Accepted and ignored.

    Returns:
        ``(train_loader, train_sampler, test_loader, test_sampler, None, None,
        transform_train)``; the samplers are None unless ``distributed`` is set.

    Raises:
        ValueError: If ``mode`` is unsupported, or if ``normalize`` is True
            without an explicit ``norm_layer``.
    """
    # Validate every argument before touching transforms or the filesystem.
    if mode not in ("org", "base"):
        raise ValueError(f"{mode} mode not supported")
    if norm_layer is None and normalize:
        # The previous placeholder used std=[0, 0, 0], which only failed once
        # the first sample was normalized; fail here with a usable message.
        raise ValueError(
            "normalize=True is not supported for this dataset: no channel "
            "statistics are defined (std would be zero); pass norm_layer explicitly"
        )

    if norm_layer is None:
        # mean 0 / std 1 leaves the tensor values unchanged
        norm_layer = transforms.Normalize(mean=[0., 0., 0.], std=[1., 1., 1.])

    trtrain = [transforms.RandomCrop(size, padding=4), transforms.RandomHorizontalFlip(),
               transforms.ToTensor(), norm_layer]
    trval = [transforms.ToTensor(), norm_layer]
    if size != 32:
        trtrain = [transforms.Resize(size)] + trtrain
        trval = [transforms.Resize(size)] + trval
    transform_test = transforms.Compose(trval)
    transform_train = transforms.Compose(trtrain) if mode == "org" else transform_test

    datadir = "./stylegan_ada/cifar10/"
    print(f"Discarding args.datadir and loading data from fixed source: {datadir}. Using {workers} workers.")

    def label_extractor(path):
        """Read the label from the suffix after the last '_' in the parent directory name."""
        return int(path.split("/")[-2].split("_")[-1])

    trainset = custom_loader_from_image_list(os.path.join(datadir, "conditional_train_split_seed_0.txt"), label_extractor, classes, training_images, transform_train)
    testset = custom_loader_from_image_list(os.path.join(datadir, "conditional_test_split_seed_0.txt"), label_extractor, classes, None, transform_test)

    train_sampler, test_sampler = None, None
    if distributed:
        print("Using DistributedSampler")
        train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(testset)

    # DataLoader rejects shuffle=True together with a sampler, so the test
    # loader keeps shuffling only in the non-distributed case.
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=(train_sampler is None), sampler=train_sampler, num_workers=workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=(test_sampler is None), sampler=test_sampler, num_workers=workers, pin_memory=True)

    return train_loader, train_sampler, test_loader, test_sampler, None, None, transform_train




